%% JRD model code
% `clear all` is intentionally left disabled: the per-participant noise
% parameters optOne/optTwo read later are not created in this script and are
% presumably expected to already be in the workspace - TODO confirm.
%clear all;
rng('shuffle');   % seed the generator from the current time

% Number of simulated participants
numParticipants = 16;

% Reference direction(s), one [x y] vector per row. Alternative model
% variants kept for quick switching:
%   refDirection = [0,1; 1,0; 0,-1; -1,0];
%   refDirection = [0,1; 1,0; 0,-1; -1,0; 1,-1];
refDirection = [1,-1];

% Condition switches:
%   refDir -> 0 for the one-view group (1 selects optTwo parameters)
%   layout -> 0 for layout 2, 1 for layout 1
refDir = 0;
layout = 0;

% Heading conditions and relative pointing conditions, in degrees
headingConditions = 0:45:315;
pointConditions = [45:45:135, 225:45:315];
numruns = 1000;

% -------------------------------------------------------------------------
% Setup that is identical for every run and participant. It is computed once
% here instead of numruns*numParticipants times inside the loops.
% -------------------------------------------------------------------------

% Model constants: l is the slope of the logistic noise function; q is the
% tolerance (degrees) within which an encoded heading counts as aligned with
% a reference direction.
l = 1;
q = 20;

% Locations of all objects relative to an origin (object id = row index).
if layout == 1
    locRelOrigin = [-1,1; 0,1; 1,1; -1,0; 0,0; 1,0; 0,-1];
else
    locRelOrigin = [0,1; -1,0; 0,0; 1,0; -1,-1; 0,-1; 1,-1];
end
numObjects = size(locRelOrigin, 1);

% locRelObj{i}(j,:) is the true vector from object i to object j.
locRelObj = cell(1, numObjects);
for i = 1:numObjects
    locRelObj{i} = locRelOrigin - locRelOrigin(i,:);
end

% Candidate test probes: every ordered triple of distinct objects, as rows
% [standing location, imagined-heading object, pointing target]. Two columns
% are then appended: the rounded allocentric heading (col 4) and pointing
% direction (col 5). Assuming thetaFunc returns the unsigned angle in degrees
% between two vectors (TODO confirm), these are bearings clockwise from [0,1].
tripleSets = nchoosek(1:numObjects, 3);
allProbes = zeros(0, 3);
for i = 1:size(tripleSets, 1)
    allProbes = [allProbes; perms(tripleSets(i,:))]; %#ok<AGROW>
end
allProbes = unique(allProbes, 'rows');
allProbes(:, 4:5) = 0;
for i = 1:size(allProbes, 1)
    fromStand = locRelObj{allProbes(i,1)};
    for c = 2:3
        v = fromStand(allProbes(i,c), :);
        if thetaFunc([1,0], v) <= thetaFunc([-1,0], v)
            allProbes(i,c+2) = round(thetaFunc([0,1], v));
        else
            allProbes(i,c+2) = round(180 + thetaFunc([0,-1], v));
        end
    end
end

numHeadings = numel(headingConditions);
numPoints = numel(pointConditions);
numRefs = size(refDirection, 1);

% Preallocate. This also discards any larger copy left in the workspace by an
% earlier execution (clear all is disabled).
avgPointErrByHeading = zeros(numHeadings, numParticipants, numruns);

for z = 1:numruns
    for p = 1:numParticipants
        % Participant-specific noise parameters.
        % NOTE(review): optOne/optTwo are not defined in this script; they are
        % presumably participant-by-2 matrices already in the workspace.
        if refDir == 1
            a = optTwo(p,1);
            b = optTwo(p,2);
        else
            a = optOne(p,1);
            b = optOne(p,2);
        end

        % Shuffle the candidates, then keep the first probe found for each
        % (heading condition, relative pointing condition) cell. This yields
        % one trial per cell (numHeadings*numPoints = 48 trials). The list is
        % rebuilt for every participant so rows cannot leak between them.
        testProbe = allProbes(randperm(size(allProbes, 1)), :);
        cellFilled = false(numHeadings, numPoints);
        newTestProbe = zeros(0, size(testProbe, 2));
        for i = 1:size(testProbe, 1)
            hIdx = find(headingConditions == testProbe(i,4), 1);
            relPoint = mod(testProbe(i,5) - testProbe(i,4), 360);
            kIdx = find(pointConditions == relPoint, 1);
            if ~isempty(hIdx) && ~isempty(kIdx) && ~cellFilled(hIdx, kIdx)
                cellFilled(hIdx, kIdx) = true;
                newTestProbe(end+1, :) = testProbe(i, :); %#ok<AGROW>
            end
        end
        newTestProbe = newTestProbe(randperm(size(newTestProbe, 1)), :);

        % Noisy encoding of every inter-object vector. For j > i, vector i->j
        % gets independent Gaussian noise on x and y. The SD is
        % a/(1+b*exp(-l*theta)), where theta is the angle to the nearest
        % reference direction. For j < i, the vector is the negation of the
        % already-encoded j->i vector, so both directions of a pair agree.
        locEncoded = cell(1, numObjects);
        angleToRef = zeros(1, numRefs);
        for i = 1:numObjects
            locEncoded{i} = zeros(numObjects, 2);
            for j = 1:numObjects
                if j < i
                    locEncoded{i}(j,:) = -locEncoded{j}(i,:);
                elseif j > i
                    v = locRelObj{i}(j,:);
                    for r = 1:numRefs
                        angleToRef(r) = round(thetaFunc(refDirection(r,:), v));
                    end
                    [~, nearestRef] = min(angleToRef);
                    sigma = a / (1 + b*exp(-l*thetaFunc(refDirection(nearestRef,:), v)));
                    locEncoded{i}(j,:) = v + [normrnd(0, sigma), normrnd(0, sigma)];
                end
                % j == i (self-vector) is never probed and stays [0,0].
            end
        end

        % Perform the JRD task for each selected probe.
        numTrials = size(newTestProbe, 1);
        % Heading and pointing conditions are the bearings stored in probe
        % columns 4 and 5.
        headingCond = newTestProbe(:,4).';
        pointingCond = newTestProbe(:,5).';
        correctPointAngle = zeros(1, numTrials);
        pointResponse = zeros(1, numTrials);
        for i = 1:numTrials
            standpos = newTestProbe(i,1);
            imagheading = newTestProbe(i,2);
            pointobj = newTestProbe(i,3);

            % Correct answer from the true (noise-free) vectors.
            correctPointAngle(i) = round(atanThetaFunc( ...
                locRelObj{standpos}(imagheading,:), locRelObj{standpos}(pointobj,:)));

            % If the encoded heading lies within q degrees of a reference
            % direction, point relative to that reference direction.
            encHeading = locEncoded{standpos}(imagheading,:);
            alignedRef = 0;
            for r = 1:numRefs
                if round(thetaFunc(encHeading, refDirection(r,:))) <= q
                    alignedRef = r;
                    break
                end
            end

            if alignedRef > 0
                pointResponse(i) = atanThetaFunc(refDirection(alignedRef,:), ...
                    locEncoded{standpos}(pointobj,:));
            else
                % Otherwise reconstruct the target by vector addition through
                % the heading object:
                % (standpos->imagheading) + (imagheading->pointobj).
                angleVector = encHeading + locEncoded{imagheading}(pointobj,:);
                pointResponse(i) = atanThetaFunc(encHeading, angleVector);
            end
        end

        % Absolute angular error folded into [0,180], grouped by heading
        % condition. Headings not in headingConditions(1:end-1) fall into the
        % last (catch-all) bin. Empty bins average to NaN.
        pointErr = abs(pointResponse - correctPointAngle);
        needsFold = pointErr > 180;
        pointErr(needsFold) = 360 - pointErr(needsFold);
        [isListed, headingBin] = ismember(headingCond, headingConditions(1:end-1));
        headingBin(~isListed) = numHeadings;
        pointErrorByHeading = cell(1, numHeadings);
        for h = 1:numHeadings
            pointErrorByHeading{h} = pointErr(headingBin == h);
            avgPointErrByHeading(h,p,z) = mean(pointErrorByHeading{h});
        end
    end
end

% Now get averages over participants. avgPointErrHeadingCollapsed(h,z) is
% the mean, over participants 1..numParticipants, of the mean pointing error
% for heading h in run z. Only the current participant and run slices are
% used, and the result is reassigned wholesale. This stops stale entries left
% in the workspace by a larger earlier simulation (clear all is disabled)
% from leaking into the average.
avgPointErrHeadingCollapsed = reshape( ...
    mean(avgPointErrByHeading(:, 1:numParticipants, 1:numruns), 2), ...
    size(avgPointErrByHeading, 1), numruns);

% Now average across runs: one value per heading condition (row vector).
% Restricted to the current numruns columns and reassigned wholesale so no
% stale values from an earlier execution survive.
avgPointErrRunsCollapsed = mean(avgPointErrHeadingCollapsed(:, 1:numruns), 2).';

% Plot the simulated mean pointing error (averaged over participants and
% runs) for each heading condition.
figure;
plot(headingConditions, avgPointErrRunsCollapsed, '-s', 'MarkerSize', 8);
xlabel('Heading Condition');
ylabel('Pointing Error');
set(gca, 'XLim', [-45 360], 'YLim', [0 50], ...
    'XTick', 0:45:315, 'YTick', 0:10:50);

%% Organize observed data
Exp2matrix = readmatrix('KellyMcNamara2010_Exp1.xlsx');

% Split the trials by layout (column 4 == 1 or 2).
Exp2matrixL1 = Exp2matrix(Exp2matrix(:,4) == 1, :);
Exp2matrixL2 = Exp2matrix(Exp2matrix(:,4) == 2, :);

% Split each layout into group 1 (first half of its rows) and group 2
% (second half). Row counts use size(...,1), because length() of a matrix is
% its largest dimension, not its number of rows.
% NOTE(review): assumes rows are ordered by group within each layout and
% that each layout has an even number of rows - confirm against the file.
halfL1 = size(Exp2matrixL1, 1) / 2;
halfL2 = size(Exp2matrixL2, 1) / 2;
Exp2matrixL1G1 = Exp2matrixL1(1:halfL1, :);
Exp2matrixL1G2 = Exp2matrixL1(halfL1+1:end, :);
Exp2matrixL2G1 = Exp2matrixL2(1:halfL2, :);
Exp2matrixL2G2 = Exp2matrixL2(halfL2+1:end, :);

% Compare the simulated heading effect with the observed group that matches
% the simulated condition: layout 1 -> L1, layout 0 -> L2; refDir 1 -> G1,
% refDir 0 -> G2.
%
% First compute the observed mean pointing error per heading for all four
% layout/group subsets, so each avgPointErrorByHeadingObs* variable is always
% defined. Column 8 holds the heading condition and column 12 the pointing
% error. Any heading not in headingConditions(1:end-1) is pooled into the
% last bin. Empty bins give NaN.
obsSubsets = {Exp2matrixL1G1, Exp2matrixL1G2, Exp2matrixL2G1, Exp2matrixL2G2};
obsAvgs = cell(1, numel(obsSubsets));
numHeadingBins = numel(headingConditions);
for s = 1:numel(obsSubsets)
    obsData = obsSubsets{s};
    [isListed, headingBin] = ismember(obsData(:,8), headingConditions(1:end-1));
    headingBin(~isListed) = numHeadingBins;
    obsAvgs{s} = accumarray(headingBin, obsData(:,12), [numHeadingBins 1], @mean, NaN).';
end
[avgPointErrorByHeadingObsL1G1, avgPointErrorByHeadingObsL1G2, ...
    avgPointErrorByHeadingObsL2G1, avgPointErrorByHeadingObsL2G2] = obsAvgs{:};

% Pick the observed curve for the simulated condition.
if layout == 1 && refDir == 1
    obsAvg = avgPointErrorByHeadingObsL1G1;
elseif layout == 1 && refDir == 0
    obsAvg = avgPointErrorByHeadingObsL1G2;
elseif layout == 0 && refDir == 1
    obsAvg = avgPointErrorByHeadingObsL2G1;
elseif layout == 0 && refDir == 0
    obsAvg = avgPointErrorByHeadingObsL2G2;
else
    obsAvg = [];
end

if ~isempty(obsAvg)
    % Correlation between the observed and run-averaged simulated heading
    % effect (intentionally displayed: no trailing semicolon).
    HeadingEffectFit = corrcoef(obsAvg, avgPointErrRunsCollapsed)

    % Plot heading effect: simulated (solid) vs observed (dashed)
    figure;
    hold on
    plot(headingConditions, avgPointErrRunsCollapsed, '-sb', 'MarkerSize', 8);
    plot(headingConditions, obsAvg, '--sb', 'MarkerSize', 8);
    hold off
    xlabel('Heading Condition');
    ylabel('Pointing Error');
    xlim([-45 360]);
    ylim([0 60]);
    xticks(0:45:315);
    yticks(0:10:50);

    % Distribution of per-run correlations with the observed data. Reset so
    % no entries from an earlier, longer simulation survive.
    coeffdist = zeros(1, numruns);
    for z = 1:numruns
        coeff = corrcoef(avgPointErrHeadingCollapsed(:,z), obsAvg);
        coeffdist(z) = coeff(1,2);
    end

    coeffquantiles = quantile(coeffdist, [.1, .5, .9]);
end

%% Plot just observed data
% Compute the observed mean pointing error per heading for all four
% layout/group subsets here, so this section works regardless of which
% layout/refDir condition was simulated. Column 8 holds the heading condition
% and column 12 the pointing error. Headings not in
% headingConditions(1:end-1) are pooled into the last bin. Empty bins give
% NaN.
obsSubsets = {Exp2matrixL1G1, Exp2matrixL1G2, Exp2matrixL2G1, Exp2matrixL2G2};
obsAvgs = cell(1, numel(obsSubsets));
numHeadingBins = numel(headingConditions);
for s = 1:numel(obsSubsets)
    obsData = obsSubsets{s};
    [isListed, headingBin] = ismember(obsData(:,8), headingConditions(1:end-1));
    headingBin(~isListed) = numHeadingBins;
    obsAvgs{s} = accumarray(headingBin, obsData(:,12), [numHeadingBins 1], @mean, NaN).';
end
[avgPointErrorByHeadingObsL1G1, avgPointErrorByHeadingObsL1G2, ...
    avgPointErrorByHeadingObsL2G1, avgPointErrorByHeadingObsL2G2] = obsAvgs{:};

% Layout 2: group 1 (solid blue) vs group 2 (dashed red)
figure;
hold on
plot(headingConditions,avgPointErrorByHeadingObsL2G1, '-sb', 'MarkerSize', 8);
plot(headingConditions,avgPointErrorByHeadingObsL2G2, '--sr', 'MarkerSize', 8);
hold off
xlabel('Heading Condition');
ylabel('Pointing Error');
xlim([-45 360]);
ylim([0 70]);
xticks(0:45:315);
yticks(0:10:70);

% Layout 1: group 1 (solid blue) vs group 2 (dashed red)
figure;
hold on
plot(headingConditions,avgPointErrorByHeadingObsL1G1, '-sb', 'MarkerSize', 8);
plot(headingConditions,avgPointErrorByHeadingObsL1G2, '--sr', 'MarkerSize', 8);
hold off
xlabel('Heading Condition');
ylabel('Pointing Error');
xlim([-45 360]);
ylim([0 70]);
xticks(0:45:315);
yticks(0:10:70);
